{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "speaking-algebra",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import os.path as osp\n",
    "import decord\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import urllib\n",
    "import moviepy.editor as mpy\n",
    "import random as rd\n",
    "from mmpose.apis import vis_pose_result\n",
    "from mmpose.models import TopDown\n",
    "from mmcv import load, dump\n",
    "\n",
    "# We assume the annotation is already prepared\n",
    "gym_train_ann_file = '../data/skeleton/gym_train.pkl'\n",
    "gym_val_ann_file = '../data/skeleton/gym_val.pkl'\n",
    "ntu60_xsub_train_ann_file = '../data/skeleton/ntu60_xsub_train.pkl'\n",
    "ntu60_xsub_val_ann_file = '../data/skeleton/ntu60_xsub_val.pkl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "alive-consolidation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text style passed to cv2.getTextSize / cv2.putText when drawing captions\n",
    "FONTFACE = cv2.FONT_HERSHEY_DUPLEX\n",
    "FONTSCALE = 0.6\n",
    "FONTCOLOR = (255, 255, 255)\n",
    "THICKNESS = 1\n",
    "LINETYPE = 1\n",
    "\n",
    "# Default fill colour of the caption box\n",
    "BGBLUE = (0, 119, 182)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ranging-conjunction",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_label(frame, label, BGCOLOR=BGBLUE):\n",
    "    threshold = 30\n",
    "    def split_label(label):\n",
    "        label = label.split()\n",
    "        lines, cline = [], ''\n",
    "        for word in label:\n",
    "            if len(cline) + len(word) < threshold:\n",
    "                cline = cline + ' ' + word\n",
    "            else:\n",
    "                lines.append(cline)\n",
    "                cline = word\n",
    "        if cline != '':\n",
    "            lines += [cline]\n",
    "        return lines\n",
    "    \n",
    "    if len(label) > 30:\n",
    "        label = split_label(label)\n",
    "    else:\n",
    "        label = [label]\n",
    "    label = ['Action: '] + label\n",
    "    \n",
    "    sizes = []\n",
    "    for line in label:\n",
    "        sizes.append(cv2.getTextSize(line, FONTFACE, FONTSCALE, THICKNESS)[0])\n",
    "    box_width = max([x[0] for x in sizes]) + 10\n",
    "    text_height = sizes[0][1]\n",
    "    box_height = len(sizes) * (text_height + 6)\n",
    "    \n",
    "    cv2.rectangle(frame, (0, 0), (box_width, box_height), BGCOLOR, -1)\n",
    "    for i, line in enumerate(label):\n",
    "        location = (5, (text_height + 6) * i + text_height + 3)\n",
    "        cv2.putText(frame, line, location, FONTFACE, FONTSCALE, FONTCOLOR, THICKNESS, LINETYPE)\n",
    "    return frame\n",
    "    \n",
    "\n",
    "def vis_skeleton(vid_path, anno, category_name=None, ratio=0.5):\n",
    "    vid = decord.VideoReader(vid_path)\n",
    "    frames = [x.asnumpy() for x in vid]\n",
    "    \n",
    "    h, w, _ = frames[0].shape\n",
    "    new_shape = (int(w * ratio), int(h * ratio))\n",
    "    frames = [cv2.resize(f, new_shape) for f in frames]\n",
    "    \n",
    "    assert len(frames) == anno['total_frames']\n",
    "    # The shape is N x T x K x 3\n",
    "    kps = np.concatenate([anno['keypoint'], anno['keypoint_score'][..., None]], axis=-1)\n",
    "    kps[..., :2] *= ratio\n",
    "    # Convert to T x N x K x 3\n",
    "    kps = kps.transpose([1, 0, 2, 3])\n",
    "    vis_frames = []\n",
    "\n",
    "    # we need an instance of TopDown model, so build a minimal one\n",
    "    model = TopDown(backbone=dict(type='ShuffleNetV1'))\n",
    "\n",
    "    for f, kp in zip(frames, kps):\n",
    "        bbox = np.zeros([0, 4], dtype=np.float32)\n",
    "        result = [dict(bbox=bbox, keypoints=k) for k in kp]\n",
    "        vis_frame = vis_pose_result(model, f, result)\n",
    "        \n",
    "        if category_name is not None:\n",
    "            vis_frame = add_label(vis_frame, category_name)\n",
    "        \n",
    "        vis_frames.append(vis_frame)\n",
    "    return vis_frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "applied-humanity",
   "metadata": {},
   "outputs": [],
   "source": [
    "keypoint_pipeline = [\n",
    "    dict(type='PoseDecode'),\n",
    "    dict(type='PoseCompact', hw_ratio=1., allow_imgpad=True),\n",
    "    dict(type='Resize', scale=(-1, 64)),\n",
    "    dict(type='CenterCrop', crop_size=64),\n",
    "    dict(type='GeneratePoseTarget', sigma=0.6, use_score=True, with_kp=True, with_limb=False)\n",
    "]\n",
    "\n",
    "limb_pipeline = [\n",
    "    dict(type='PoseDecode'),\n",
    "    dict(type='PoseCompact', hw_ratio=1., allow_imgpad=True),\n",
    "    dict(type='Resize', scale=(-1, 64)),\n",
    "    dict(type='CenterCrop', crop_size=64),\n",
    "    dict(type='GeneratePoseTarget', sigma=0.6, use_score=True, with_kp=False, with_limb=True)\n",
    "]\n",
    "\n",
    "from mmaction.datasets.pipelines import Compose\n",
    "def get_pseudo_heatmap(anno, flag='keypoint'):\n",
    "    assert flag in ['keypoint', 'limb']\n",
    "    pipeline = Compose(keypoint_pipeline if flag == 'keypoint' else limb_pipeline)\n",
    "    return pipeline(anno)['imgs']\n",
    "\n",
    "def vis_heatmaps(heatmaps, channel=-1, ratio=8):\n",
    "    # if channel is -1, draw all keypoints / limbs on the same map\n",
    "    import matplotlib.cm as cm\n",
    "    h, w, _ = heatmaps[0].shape\n",
    "    newh, neww = int(h * ratio), int(w * ratio)\n",
    "    \n",
    "    if channel == -1:\n",
    "        heatmaps = [np.max(x, axis=-1) for x in heatmaps]\n",
    "    cmap = cm.viridis\n",
    "    heatmaps = [(cmap(x)[..., :3] * 255).astype(np.uint8) for x in heatmaps]\n",
    "    heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps]\n",
    "    return heatmaps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "automatic-commons",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load GYM annotations\n",
    "gym_categories_url = 'https://sdolivia.github.io/FineGym/resources/dataset/gym99_categories.txt'\n",
    "# Read the category list inside a with-block so the HTTP response is closed afterwards\n",
    "with urllib.request.urlopen(gym_categories_url) as response:\n",
    "    lines = list(response)\n",
    "# Each line is '; '-separated; the last field is used as the category name\n",
    "gym_categories = [x.decode().strip().split('; ')[-1] for x in lines]\n",
    "gym_annos = load(gym_train_ann_file) + load(gym_val_ann_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "numerous-bristol",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-04-25 22:18:53--  https://download.openmmlab.com/mmaction/posec3d/gym_samples.tar\n",
      "Resolving download.openmmlab.com (download.openmmlab.com)... 124.160.145.22\n",
      "Connecting to download.openmmlab.com (download.openmmlab.com)|124.160.145.22|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 36300800 (35M) [application/x-tar]\n",
      "Saving to: ‘gym_samples.tar’\n",
      "\n",
      "100%[======================================>] 36,300,800  11.5MB/s   in 3.0s   \n",
      "\n",
      "2021-04-25 22:18:58 (11.5 MB/s) - ‘gym_samples.tar’ saved [36300800/36300800]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# download sample videos of GYM; skipped when gym_samples/ already exists,\n",
    "# so that re-running the notebook does not fetch the archive again\n",
    "if not osp.isdir('gym_samples'):\n",
    "    !wget https://download.openmmlab.com/mmaction/posec3d/gym_samples.tar\n",
    "    !tar -xf gym_samples.tar\n",
    "    !rm gym_samples.tar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "ranging-harrison",
   "metadata": {},
   "outputs": [],
   "source": [
    "gym_root = 'gym_samples/'\n",
    "# os.listdir returns names in arbitrary order; sort them so that idx is reproducible\n",
    "gym_vids = sorted(os.listdir(gym_root))\n",
    "# visualize pose of which video? index into gym_vids\n",
    "idx = 1\n",
    "vid = gym_vids[idx]\n",
    "\n",
    "# the annotation is looked up by the file name up to the first '.'\n",
    "frame_dir = vid.split('.')[0]\n",
    "vid_path = osp.join(gym_root, vid)\n",
    "anno = [x for x in gym_annos if x['frame_dir'] == frame_dir][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "fitting-courage",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the skeleton of the selected GYM video and play the result inline\n",
    "vis_frames = vis_skeleton(\n",
    "    vid_path=vid_path,\n",
    "    anno=anno,\n",
    "    category_name=gym_categories[anno['label']])\n",
    "vid = mpy.ImageSequenceClip(vis_frames, fps=24)\n",
    "vid.ipython_display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "orange-logging",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keypoint pseudo heatmaps of the selected GYM video, captioned with its category\n",
    "keypoint_heatmap = get_pseudo_heatmap(anno, flag='keypoint')\n",
    "keypoint_mapvis = vis_heatmaps(keypoint_heatmap)\n",
    "keypoint_mapvis = [\n",
    "    add_label(frame, gym_categories[anno['label']])\n",
    "    for frame in keypoint_mapvis\n",
    "]\n",
    "vid = mpy.ImageSequenceClip(keypoint_mapvis, fps=24)\n",
    "vid.ipython_display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "residential-conjunction",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Limb pseudo heatmaps of the selected GYM video, captioned with its category\n",
    "limb_heatmap = get_pseudo_heatmap(anno, flag='limb')\n",
    "limb_mapvis = vis_heatmaps(limb_heatmap)\n",
    "limb_mapvis = [\n",
    "    add_label(frame, gym_categories[anno['label']])\n",
    "    for frame in limb_mapvis\n",
    "]\n",
    "vid = mpy.ImageSequenceClip(limb_mapvis, fps=24)\n",
    "vid.ipython_display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "coupled-stranger",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The name list of the 60 NTU action categories; the list position is used as the label index\n",
    "ntu_categories = ['drink water', 'eat meal/snack', 'brushing teeth', 'brushing hair', 'drop', 'pickup',\n",
    "                  'throw', 'sitting down', 'standing up (from sitting position)', 'clapping', 'reading',\n",
    "                  'writing', 'tear up paper', 'wear jacket', 'take off jacket', 'wear a shoe',\n",
    "                  'take off a shoe', 'wear on glasses', 'take off glasses', 'put on a hat/cap',\n",
    "                  'take off a hat/cap', 'cheer up', 'hand waving', 'kicking something',\n",
    "                  'reach into pocket', 'hopping (one foot jumping)', 'jump up',\n",
    "                  'make a phone call/answer phone', 'playing with phone/tablet', 'typing on a keyboard',\n",
    "                  'pointing to something with finger', 'taking a selfie', 'check time (from watch)',\n",
    "                  'rub two hands together', 'nod head/bow', 'shake head', 'wipe face', 'salute',\n",
    "                  'put the palms together', 'cross hands in front (say stop)', 'sneeze/cough',\n",
    "                  'staggering', 'falling', 'touch head (headache)', 'touch chest (stomachache/heart pain)',\n",
    "                  'touch back (backache)', 'touch neck (neckache)', 'nausea or vomiting condition',\n",
    "                  'use a fan (with hand or paper)/feeling warm', 'punching/slapping other person',\n",
    "                  'kicking other person', 'pushing other person', 'pat on back of other person',\n",
    "                  'point finger at the other person', 'hugging other person',\n",
    "                  'giving something to other person', \"touch other person's pocket\", 'handshaking',\n",
    "                  'walking towards each other', 'walking apart from each other']\n",
    "# Train and val annotations are merged so that any sample video can be looked up\n",
    "ntu_annos = load(ntu60_xsub_train_ann_file) + load(ntu60_xsub_val_ann_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "seasonal-palmer",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-04-25 22:21:16--  https://download.openmmlab.com/mmaction/posec3d/ntu_samples.tar\n",
      "Resolving download.openmmlab.com (download.openmmlab.com)... 124.160.145.22\n",
      "Connecting to download.openmmlab.com (download.openmmlab.com)|124.160.145.22|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 121753600 (116M) [application/x-tar]\n",
      "Saving to: ‘ntu_samples.tar’\n",
      "\n",
      "100%[======================================>] 121,753,600 14.4MB/s   in 9.2s   \n",
      "\n",
      "2021-04-25 22:21:26 (12.6 MB/s) - ‘ntu_samples.tar’ saved [121753600/121753600]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# download sample videos of NTU-60; this has to run before ntu_samples/ is listed.\n",
    "# Skipped when ntu_samples/ already exists, so re-runs do not fetch the archive again\n",
    "if not osp.isdir('ntu_samples'):\n",
    "    !wget https://download.openmmlab.com/mmaction/posec3d/ntu_samples.tar\n",
    "    !tar -xf ntu_samples.tar\n",
    "    !rm ntu_samples.tar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "critical-review",
   "metadata": {},
   "outputs": [],
   "source": [
    "ntu_root = 'ntu_samples/'\n",
    "# os.listdir returns names in arbitrary order; sort them so that idx is reproducible\n",
    "ntu_vids = sorted(os.listdir(ntu_root))\n",
    "# visualize pose of which video? index into ntu_vids\n",
    "idx = 20\n",
    "vid = ntu_vids[idx]\n",
    "\n",
    "frame_dir = vid.split('.')[0]\n",
    "vid_path = osp.join(ntu_root, vid)\n",
    "# the annotation is looked up by the part of the file name before the first '_'\n",
    "anno = [x for x in ntu_annos if x['frame_dir'] == frame_dir.split('_')[0]][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "accompanied-invitation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the skeleton of the selected NTU video and play the result inline\n",
    "vis_frames = vis_skeleton(\n",
    "    vid_path=vid_path,\n",
    "    anno=anno,\n",
    "    category_name=ntu_categories[anno['label']])\n",
    "vid = mpy.ImageSequenceClip(vis_frames, fps=24)\n",
    "vid.ipython_display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "respiratory-conclusion",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keypoint pseudo heatmaps of the selected NTU video.\n",
    "# anno comes from ntu_annos here, so its label indexes ntu_categories (not gym_categories)\n",
    "keypoint_heatmap = get_pseudo_heatmap(anno)\n",
    "keypoint_mapvis = vis_heatmaps(keypoint_heatmap)\n",
    "keypoint_mapvis = [add_label(f, ntu_categories[anno['label']]) for f in keypoint_mapvis]\n",
    "vid = mpy.ImageSequenceClip(keypoint_mapvis, fps=24)\n",
    "vid.ipython_display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "thirty-vancouver",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Limb pseudo heatmaps of the selected NTU video.\n",
    "# anno comes from ntu_annos here, so its label indexes ntu_categories (not gym_categories)\n",
    "limb_heatmap = get_pseudo_heatmap(anno, 'limb')\n",
    "limb_mapvis = vis_heatmaps(limb_heatmap)\n",
    "limb_mapvis = [add_label(f, ntu_categories[anno['label']]) for f in limb_mapvis]\n",
    "vid = mpy.ImageSequenceClip(limb_mapvis, fps=24)\n",
    "vid.ipython_display()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
